// Copyright (C) 2024 Kumo inc.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//

#pragma once

#include <cstdint>
#include <string>
#include <vector>

#include <llama.h>

#include <kllm/core/control_vector.h>
#include <kllm/core/kv_override.h>
#include <kllm/core/lora_adapter.h>
#include <kllm/core/sampler.h>
#include <kllm/utility/cpu.h>
#include <kllm/utility/types.h>

#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

namespace kllm {

    struct KMParams {
        int32_t n_predict = -1; // new tokens to predict
        int32_t n_ctx = 4096; // context size
        int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
        int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
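        // Note (assumed from llama.cpp conventions): the physical batch n_ubatch should not
        // exceed the logical batch n_batch; e.g. n_batch = 2048 with n_ubatch = 512 splits a
        // 2048-token prompt into four physical passes of 512 tokens each.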
        int32_t n_keep = 0; // number of tokens to keep from initial prompt
        int32_t n_draft = 5; // number of tokens to draft during speculative decoding
        int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
        int32_t n_parallel = 1; // number of parallel sequences to decode
        int32_t n_sequences = 1; // number of sequences to decode
        float p_split = 0.1f; // speculative decoding split probability
        int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 = use default)
        int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 = use default)
        int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
        float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
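        // tensor_split holds proportional weights, not percentages; e.g. {3.0f, 1.0f} aims to
        // place roughly 3/4 of the model on the first GPU and 1/4 on the second (an
        // illustrative sketch; the exact distribution is decided by the backend).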
        int32_t grp_attn_n = 1; // group-attention factor
        int32_t grp_attn_w = 512; // group-attention width
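        // Example (self-extend sketch): for a model trained at 2048 context, grp_attn_n = 4
        // with grp_attn_w = 1024 applies grouped attention to stretch the usable context by
        // roughly the factor grp_attn_n; grp_attn_w is commonly kept a multiple of grp_attn_n.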
        int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
        float rope_freq_base = 0.0f; // RoPE base frequency
        float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
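        // A value of 0.0f means "use the model's own RoPE settings"; e.g. rope_freq_scale =
        // 0.5f corresponds to linear RoPE scaling that roughly doubles the usable context
        // (illustrative; output quality at the extended length depends on the model).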
        float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
        float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
        float yarn_beta_fast = 32.0f; // YaRN low correction dim
        float yarn_beta_slow = 1.0f; // YaRN high correction dim
        int32_t yarn_orig_ctx = 0; // YaRN original context length
        float defrag_thold = -1.0f; // KV cache defragmentation threshold

        struct cpu_params cpuparams;
        struct cpu_params cpuparams_batch;
        struct cpu_params draft_cpuparams;
        struct cpu_params draft_cpuparams_batch;

        ggml_backend_sched_eval_callback cb_eval = nullptr;
        void *cb_eval_user_data = nullptr;

        ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

        enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
        enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
        enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
        enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

        struct InternalSamplerParams sparams;

        std::string model = ""; // model path                                                    // NOLINT
        std::string model_draft = ""; // draft model for speculative decoding                          // NOLINT
        std::string model_alias = "unknown"; // model alias                                            // NOLINT
        std::string model_url = ""; // model url to download                                         // NOLINT
        std::string hf_token = ""; // HF token                                                      // NOLINT
        std::string hf_repo = ""; // HF repo                                                       // NOLINT
        std::string hf_file = ""; // HF file                                                       // NOLINT
        std::string prompt = "";                                                                  // NOLINT
        std::string prompt_file = ""; // store the external prompt file name                           // NOLINT
        std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state             // NOLINT
        std::string input_prefix = ""; // string to prefix user inputs with                             // NOLINT
        std::string input_suffix = ""; // string to suffix user inputs with                             // NOLINT
        std::string logdir = ""; // directory in which to save YAML log files                     // NOLINT
        std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding           // NOLINT
        std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding          // NOLINT
        std::string logits_file = ""; // file for saving *all* logits                                  // NOLINT
        std::string rpc_servers = ""; // comma separated list of RPC servers                           // NOLINT

        std::vector<std::string> in_files;   // all input files
        std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
        std::vector<llama_model_kv_override> kv_overrides;
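        // Example (mirroring llama.cpp's --override-kv KEY=TYPE:VALUE syntax): an override such
        // as "tokenizer.ggml.add_bos_token=bool:false" would be appended here as a
        // llama_model_kv_override entry with its key, value-type tag, and value filled in.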

        bool lora_init_without_apply = false; // load LoRA adapters into memory without applying them to the context (they can be applied later via llama_lora_adapter_apply)
        std::vector<common_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale

        std::vector<ControlVectorLoadInfo> control_vectors; // control vector with user defined scale

        int32_t verbosity = 0;
        int32_t control_vector_layer_start = -1; // first layer the control vector is applied to
        int32_t control_vector_layer_end = -1; // last layer the control vector is applied to

        int32_t ppl_stride = 0; // stride for perplexity calculations (0 = use the default approach)
        int32_t ppl_output_type = 0; // 0 -> usual ppl output; 1 -> one "num_tokens, ppl" pair per line
                                     // (more convenient for plotting)

        bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
        size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score

        bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt
        size_t winogrande_tasks = 0;     // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed

        bool multiple_choice = false;  // compute TruthfulQA score over random tasks from datafile supplied in prompt
        size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed

        bool kl_divergence = false; // compute KL divergence

        bool usage = false; // print usage
        bool use_color = false; // use color to distinguish generations and inputs
        bool special = false; // enable special token output
        bool interactive = false; // interactive mode
        bool interactive_first = false; // wait for user input immediately
        bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
        bool prompt_cache_all = false; // save user input and generations to prompt cache
        bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it

        bool escape = true;  // escape "\n", "\r", "\t", "\'", "\"", and "\\"
        bool multiline_input = false; // reverse the usage of `\`
        bool simple_io = false; // improves compatibility with subprocesses and limited consoles
        bool cont_batching = true;  // insert new sequences for decoding on-the-fly
        bool flash_attn = false; // flash attention
        bool no_perf = false; // disable performance metrics
        bool ctx_shift = true;  // context shift on infinite text generation

        bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
        bool logits_all = false; // return logits for all tokens in the batch
        bool use_mmap = true;  // use mmap for faster loads
        bool use_mlock = false; // use mlock to keep model in memory
        bool verbose_prompt = false; // print prompt tokens before generation
        bool display_prompt = true;  // print prompt before generation
        bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
        bool no_kv_offload = false; // disable KV offloading
        bool warmup = true;  // warmup run
        bool check_tensors = false; // validate tensor data

        std::string cache_type_k = "f16"; // KV cache data type for the K
        std::string cache_type_v = "f16"; // KV cache data type for the V

        // multimodal models (see examples/llava)
        std::string mmproj = "";        // path to multimodal projector                                         // NOLINT
        std::vector<std::string> image; // path to image file(s)

        // embedding
        bool embedding = false; // get only sentence embedding
        int32_t embd_normalize = 2;     // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
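        // e.g. embd_normalize == 2 rescales an embedding e to e / ||e||_2 (unit euclidean
        // norm), 1 divides by the L1 norm, and -1 leaves the raw values untouched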
        std::string embd_out = "";    // empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
        std::string embd_sep = "\n";  // separator of embeddings
        bool reranking = false; // enable reranking support on server

        // server params
        int32_t port = 8080;         // server listens on this network port
        int32_t timeout_read = 600;          // http read timeout in seconds
        int32_t timeout_write = timeout_read; // http write timeout in seconds
        int32_t n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
        int32_t n_cache_reuse = 0;            // min chunk size to reuse from the cache via KV shifting

        std::string hostname = "127.0.0.1";
        std::string public_path = "";                                                                         // NOLINT
        std::string chat_template = "";                                                                         // NOLINT
        bool enable_chat_template = true;

        std::vector<std::string> api_keys;

        std::string ssl_file_key = "";                                                                         // NOLINT
        std::string ssl_file_cert = "";                                                                         // NOLINT

        // "advanced" endpoints are disabled by default for better security
        bool webui = true;
        bool endpoint_slots = false;
        bool endpoint_props = false; // only control POST requests, not GET
        bool endpoint_metrics = false;

        bool log_json = false;

        std::string slot_save_path;

        float slot_prompt_similarity = 0.5f;

        // batched-bench params
        bool is_pp_shared = false;

        std::vector<int32_t> n_pp;
        std::vector<int32_t> n_tg;
        std::vector<int32_t> n_pl;

        // retrieval params
        std::vector<std::string> context_files; // context files to embed

        int32_t chunk_size = 64; // chunk size for context embedding

        std::string chunk_separator = "\n"; // chunk separator for context embedding

        // passkey params
        int32_t n_junk = 250; // number of times to repeat the junk text
        int32_t i_pos = -1;  // position of the passkey in the junk text

        // imatrix params
        std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file

        int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
        int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
        int32_t i_chunk = 0; // start processing from this chunk

        bool process_output = false; // collect data for the output tensor
        bool compute_ppl = true;  // whether to compute perplexity

        // cvector-generator params
        int n_pca_batch = 100;
        int n_pca_iterations = 1000;
        dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
        std::string cvector_outfile = "control_vector.gguf";
        std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
        std::string cvector_negative_file = "examples/cvector-generator/negative.txt";

        bool spm_infill = false; // suffix/prefix/middle pattern for infill

        std::string lora_outfile = "ggml-lora-merged-f16.gguf";

        // batched-bench params
        bool batched_bench_output_jsonl = false;

        // Verify (and, if needed, derive) model_alias from the configured model; see the
        // implementation for the exact rules.
        void verify_model_alias();
    };

    // Build a human-readable summary of the system/runtime configuration implied by params
    // (e.g. thread counts and backend capabilities).
    std::string common_params_get_system_info(const KMParams &params);
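
    // A minimal usage sketch (illustrative only; the field values below are assumptions,
    // not recommended defaults):
    //
    //     KMParams params;
    //     params.model = DEFAULT_MODEL_PATH;  // or a user-supplied path
    //     params.n_ctx = 8192;
    //     params.n_gpu_layers = 32;           // offload some layers if VRAM allows
    //     params.verify_model_alias();
    //     std::string info = common_params_get_system_info(params);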
}  // namespace kllm
